
Cache RoPE freqs on device to avoid repeated CPU-GPU copy in QwenImage #13406

Open

akshan-main wants to merge 1 commit into huggingface:main from akshan-main:fix-qwenimage-rope-sync

Conversation

@akshan-main

@akshan-main akshan-main commented Apr 3, 2026

What does this PR do?

Part of #13401

QwenEmbedRope.forward() copies pos_freqs and neg_freqs from CPU to GPU via .to(device) on every transformer forward call. These tensors are fixed at init and never change, so the repeated transfer triggers an unnecessary cudaStreamSynchronize (~76ms each).

Added a _get_device_freqs() helper that caches the GPU copy on the first call. Applied to both QwenEmbedRope and QwenEmbedLayer3DRope.

(register_buffer can't be used here because it drops the imaginary part of complex tensors)
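The caching pattern can be sketched as follows. This is a minimal illustration, not the actual diffusers implementation: the frequency construction and the `_cached_device` / `_pos_freqs_device` attribute names are assumptions; only `pos_freqs`, `neg_freqs`, and `_get_device_freqs` come from the PR description.

```python
import torch


class QwenEmbedRope(torch.nn.Module):
    """Minimal sketch of the RoPE device-caching pattern (illustrative only)."""

    def __init__(self, dim: int = 64, max_pos: int = 1024):
        super().__init__()
        freqs = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(torch.arange(max_pos).float(), freqs)
        # Complex tensors are kept as plain attributes: register_buffer would
        # lose the imaginary part, per the PR description.
        self.pos_freqs = torch.polar(torch.ones_like(angles), angles)
        self.neg_freqs = torch.polar(torch.ones_like(angles), -angles)
        self._cached_device = None
        self._pos_freqs_device = None
        self._neg_freqs_device = None

    def _get_device_freqs(self, device: torch.device):
        # Copy to the target device once and reuse the cached copy on later
        # calls, instead of a CPU->GPU transfer (and the cudaStreamSynchronize
        # it triggers) on every forward.
        if self._cached_device != device:
            self._pos_freqs_device = self.pos_freqs.to(device)
            self._neg_freqs_device = self.neg_freqs.to(device)
            self._cached_device = device
        return self._pos_freqs_device, self._neg_freqs_device
```

The second call with the same device returns the cached tensors without any transfer.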

Profiling (A100 80GB, eager, 2 steps, 1024x1024)

```
                               BEFORE        AFTER
Big syncs (>50 ms)                  3            0
Big sync total (ms)             228.7          0.0

Big syncs before: [76.6, 76.4, 75.7]
Big syncs after:  []
```
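A sync count like the one in the table can be reproduced with plain `torch.profiler`. This is a stand-in sketch, not the actual #13356 tooling; `count_big_syncs` is a hypothetical helper, and the CPU-only matmul stands in for the pipeline call.

```python
import torch
from torch.profiler import profile, ProfilerActivity


def count_big_syncs(prof, threshold_ms: float = 50.0):
    """Count cudaStreamSynchronize events longer than threshold_ms."""
    syncs_ms = [
        evt.cpu_time_total / 1000.0  # microseconds -> milliseconds
        for evt in prof.events()
        if "cudaStreamSynchronize" in evt.name
    ]
    big = [ms for ms in syncs_ms if ms > threshold_ms]
    return len(big), sum(big)


# CPU-only demonstration; on a GPU run you would add ProfilerActivity.CUDA
# and wrap the pipeline/transformer call instead of this matmul.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.randn(64, 64) @ torch.randn(64, 64)

n_big, total_ms = count_big_syncs(prof)
```

On a CPU-only run there are no CUDA sync events, so both counts are zero; the point is the filtering pattern over `prof.events()`.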

Before (76ms cudaStreamSynchronize inside transformer_forward):

[screenshot: before_sync]

After (no sync gap):

[screenshot: after_sync]

Profiled with the tooling from #13356. Reproduction notebook.


Who can review?

@sayakpaul @dg845

@akshan-main

akshan-main commented Apr 3, 2026

The profiling was done with 2 steps, but this sync happens on every transformer forward call, so at 20 inference steps the fix eliminates ~1.5s of CPU-GPU sync overhead per run. Under torch.compile the impact is larger, since GPU queues are deeper and each sync stalls longer (~80ms vs ~76ms in eager).
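For reference, the back-of-envelope arithmetic behind the ~1.5s figure, using the three sync durations measured above:

```python
# Numbers from the profiling table; 76.2 ms is the mean of the three
# observed big syncs at 2 steps.
per_sync_ms = (76.6 + 76.4 + 75.7) / 3   # ~76.2 ms per forward call
steps = 20
saved_s = steps * per_sync_ms / 1000.0   # ~1.52 s saved per 20-step run
```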

@akshan-main

Oh, and this fix applies to all QwenImage variants (Edit, EditPlus, Layered), since they share the same transformer.

@dg845 dg845 requested review from dg845 and sayakpaul April 8, 2026 05:39
@sayakpaul
Member

@akshan-main thanks for this! In the second plot, could you tell which one of the blocks the reported duration belongs to?


@sayakpaul sayakpaul left a comment


Seems like a clean fix to me. But I will let @dg845 make the final merge.

@akshan-main

The selected slice in the after image is the transformer_forward user_annotation itself (~439ms), wrapping the full QwenImageTransformer2DModel.forward.

I'm highlighting the sub-block where the 76ms cudaStreamSynchronize used to sit (visible in the before screenshot); it is now gone.

@akshan-main

The ~439ms is for the entire transformer_forward block.

